import json
from pathlib import Path
from typing import List, Dict
from collections import defaultdict
import csv

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.patches import Circle
import pickle

# Match the font used in bar.py
plt.rcParams["font.family"] = "DejaVu Sans"

# -----------------------------------------------------------------------------
# LEGEND POSITION (adjust here if needed)
# -----------------------------------------------------------------------------
# Passed to ax.legend(bbox_to_anchor=...) together with loc="lower center".
LEGEND_X = 0.5  # horizontal offset (0 = left edge, 0.5 = centre)
LEGEND_Y = -0.3  # vertical offset (negative = below x-axis)

# -----------------------------------------------------------------------------
# CONFIG
# -----------------------------------------------------------------------------
# Directory containing this script; main() writes the PNG and the pickle cache here.
DATA_DIR = Path(__file__).parent
# Input locations. The "/path/to/..." values are placeholders to be replaced.
SOCIAL_JSON = "/path/to/divergence_gpt_4o_webaes_static_final.json"  # social-agents predictions
VANILLA_JSON = "/path/to/gpt_merged_webaes_npten.json"  # vanilla ×10 predictions
HUMAN_GROUP_CSV = "/path/to/website_group_mean.csv"  # per-website group mean responses
HUMAN_INDIVIDUAL_CSV = "/path/to/ae_only_unambiguous_1000.csv"  # per-website individual responses

# Choose source for ground truth distribution: "group_means" or "individual".
# main() only tests for "individual"; any other value selects the group-means path.
TRUTH_SOURCE = "individual"

# Fill colors (RGBA) - vibrant but still light
SOCIAL_COLOR = (0.45, 0.85, 0.45, 1.0)       # brighter pastel green
VANILLA_COLOR = (0.95, 0.45, 0.45, 1.0)      # brighter pastel red
TRUTH_COLOR = (0.75, 0.55, 0.90, 1.0)        # brighter pastel violet
# NOTE(review): not referenced by the code visible in this file.
SOCIAL_PERSONA_COLOR = (0.40, 0.65, 0.95, 1.0) # brighter pastel blue

# Outline colors (RGBA) for the KDE curves - darker than the fills for contrast
EDGE_SOCIAL = (0.35, 0.70, 0.35, 0.9)        # darker green edge
EDGE_VANILLA = (0.90, 0.35, 0.35, 0.9)       # darker red edge
# NOTE(review): not referenced by the code visible in this file.
EDGE_SOCIAL_PERSONA = (0.25, 0.45, 0.80, 0.9) # darker blue edge

# Vis-only widening factor for plotted normals.
# NOTE(review): nothing in the visible code reads this constant; it looks like
# a leftover from the removed overlap/normal-curve drawing - confirm before use.
SIGMA_DRAW_SCALE = 1.6


# -----------------------------------------------------------------------------
# UTILITIES
# -----------------------------------------------------------------------------

def _load_json(path: Path):
    """Decode the UTF-8 JSON document stored at *path* and return it."""
    with open(path, mode="r", encoding="utf-8") as source:
        payload = json.load(source)
    return payload


def _extract_prediction(item: dict):
    """Return a single numeric prediction if present in *item*.

    The JSON structures coming from different experiments are not consistent,
    so the following keys are tried in order and the first value that is not
    ``None`` is returned unchanged (no casting is done here):

    - ``overall_mean_prediction``: aggregated prediction across personas
      (preferred)
    - ``mean_prediction``: aggregated mean under its shorter key
    - ``no_persona_prediction``: single prediction without persona context

    The legacy ``predictions`` list is intentionally NOT averaged here, to
    avoid collapsing persona-level predictions into one number. Use
    ``_extract_persona_predictions(item)`` where the individual persona values
    are needed.

    Returns:
        The stored prediction value, or ``None`` if none of the keys holds a
        usable value.
    """
    for key in ("overall_mean_prediction", "mean_prediction", "no_persona_prediction"):
        pred = item.get(key)
        if pred is not None:
            return pred

    # No usable prediction found
    return None


# New: robust extraction of per-persona predictions
def _extract_persona_predictions(item: dict) -> List[float]:
    """Return a list of per-persona predictions for a sample if available.

    Schema variants are tried in this order; the first one that yields at
    least one usable number wins:

    1. ``persona_predictions`` as ``list[float]`` (``None`` entries dropped;
       if any remaining entry is non-numeric the whole list is rejected and
       the later variants are tried).
    2. ``persona_predictions`` as a dict. Each value is either a number, or a
       record ``{mean_prediction: float, predictions: list[float], ...}``.
       For records the preference is ``mean_prediction``, else the mean of the
       numeric entries of ``predictions``. Values that cannot be turned into a
       number are skipped.
    3. ``persona_to_prediction``: ``dict[str, float]``.
    4. ``personas`` (or ``persona_records``): ``list[dict]`` with a numeric
       ``prediction`` or ``score`` field.
    5. Legacy ``predictions`` list (may represent personas).

    Returns:
        The per-persona values as floats, or ``[]`` when nothing usable is
        found. No NaN is ever produced by averaging an empty list.
    """
    vals = item.get("persona_predictions")

    # 1. Direct list
    if isinstance(vals, list) and vals:
        try:
            return [float(v) for v in vals if v is not None]
        except (TypeError, ValueError):
            pass  # mixed / non-numeric list: fall through to the other schemas

    # 2. Dict mapping persona -> prediction or persona -> {mean_prediction, predictions, ...}
    if isinstance(vals, dict) and vals:
        out: List[float] = []
        for v in vals.values():
            if isinstance(v, dict):
                # Prefer explicit mean_prediction
                mp = v.get("mean_prediction")
                if mp is not None:
                    try:
                        out.append(float(mp))
                        continue
                    except (TypeError, ValueError):
                        pass
                # Fallback: average of predictions list
                preds = v.get("predictions")
                if isinstance(preds, list) and preds:
                    try:
                        numeric = [float(x) for x in preds if x is not None]
                    except (TypeError, ValueError):
                        numeric = []
                    # Only average when something numeric remains; np.mean([])
                    # would return NaN and contaminate every downstream statistic.
                    if numeric:
                        out.append(float(np.mean(numeric)))
                        continue
            # Try to cast raw value to float (fails, and is skipped, for records)
            try:
                out.append(float(v))
            except (TypeError, ValueError):
                continue
        if out:
            return out

    # 3. Alternate dict key
    alt = item.get("persona_to_prediction")
    if isinstance(alt, dict) and alt:
        out2: List[float] = []
        for v in alt.values():
            try:
                out2.append(float(v))
            except (TypeError, ValueError):
                continue
        if out2:
            return out2

    # 4. Structured list of persona records
    persons = item.get("personas") or item.get("persona_records")
    if isinstance(persons, list) and persons:
        out3: List[float] = []
        for rec in persons:
            if not isinstance(rec, dict):
                continue
            val = rec.get("prediction")
            if val is None:
                val = rec.get("score")
            if val is None:
                continue
            try:
                out3.append(float(val))
            except (TypeError, ValueError):
                continue
        if out3:
            return out3

    # 5. Fallback to legacy "predictions" list
    preds_list = item.get("predictions")
    if isinstance(preds_list, list) and preds_list:
        try:
            return [float(v) for v in preds_list if v is not None]
        except (TypeError, ValueError):
            pass

    return []


def _extract_all_predictions(item: dict) -> List[float]:
    """Collect the replicate predictions stored on a vanilla-run record.

    Candidate locations, checked in this order:
    - item["all_predictions"]
    - item["predictions"]
    - item["baseline_response"]["baseline_gpt4o"]["all_predictions"], or its
      "predictions" sibling when the former is missing/empty

    The first candidate that is a non-empty list containing at least one value
    castable to float is returned (non-castable entries dropped). Returns []
    when no candidate qualifies.
    """

    def _castable_only(values) -> List[float]:
        # Keep what float() accepts; silently drop everything else.
        kept: List[float] = []
        for raw in values:
            try:
                kept.append(float(raw))
            except (TypeError, ValueError):
                pass
        return kept

    candidates = [item.get("all_predictions"), item.get("predictions")]

    nested = item.get("baseline_response")
    if isinstance(nested, dict):
        model_block = nested.get("baseline_gpt4o")
        if isinstance(model_block, dict):
            candidates.append(
                model_block.get("all_predictions") or model_block.get("predictions")
            )

    for candidate in candidates:
        if not (isinstance(candidate, list) and candidate):
            continue
        numbers = _castable_only(candidate)
        if numbers:
            return numbers

    return []


def _collect_errors_social(data: List[dict]) -> List[float]:
    """Absolute error |overall_mean_prediction − mean_score| per record.

    Records missing either field are left out of the result.
    """
    pairs = (
        (record.get("mean_score"), record.get("overall_mean_prediction"))
        for record in data
    )
    return [
        abs(guess - target)
        for target, guess in pairs
        if target is not None and guess is not None
    ]


def _collect_errors_vanilla(data: List[dict]) -> List[float]:
    """Return |mean(predictions) − truth| for each image in the vanilla run.

    Records are skipped when they have no ``mean_score`` (there is nothing to
    compare against) or when ``predictions`` is missing/empty.
    """
    errors = []
    for item in data:
        truth = item.get("mean_score")
        preds = item.get("predictions") or []
        # Without this guard a missing truth would reach `pred_mean - None`
        # and abort the whole collection with a TypeError.
        if truth is None or not preds:
            continue
        pred_mean = float(np.mean(preds))
        errors.append(abs(pred_mean - truth))
    return errors


# -----------------------------------------------------------------------------
# PREDICTION COLLECTION FUNCTIONS (UPDATED FOR SAMPLE-BY-SAMPLE)
# -----------------------------------------------------------------------------

def _collect_preds_social_normalized(data: List[dict]) -> List[float]:
    """Signed error (prediction − mean_score) per social-agent record.

    The prediction comes from ``_extract_prediction``; records lacking either
    the truth or a prediction contribute nothing.
    """
    deltas: List[float] = []
    for record in data:
        target = record.get("mean_score")
        guess = _extract_prediction(record)
        if target is None or guess is None:
            continue
        deltas.append(float(guess) - float(target))
    return deltas


def _collect_preds_vanilla_normalized(social_data: List[dict], vanilla_data: List[dict]) -> List[float]:
    """Signed error (prediction − ground truth) per vanilla record.

    The ground truth ("mean_score") is read from the social record at the same
    position, since the vanilla records might not carry it. Pairing is purely
    positional and stops at the shorter of the two lists, so both inputs are
    expected to list the same samples in the same order.
    """
    deltas: List[float] = []
    for social_record, vanilla_record in zip(social_data, vanilla_data):
        target = social_record.get("mean_score")
        guess = _extract_prediction(vanilla_record)
        if target is None or guess is None:
            continue
        deltas.append(float(guess) - float(target))
    return deltas


def _collect_truth_scores_normalized(data: List[dict]) -> List[float]:
    """Centre each record's individual human ratings on that record's mean.

    Ratings are read from the "scores" field. For a record with ratings, one
    value (rating − mean of that record's ratings) is emitted per rating. A
    record without ratings but with a "mean_score" emits a single 0.0; a
    record with neither emits nothing.
    """
    deviations: List[float] = []
    for record in data:
        ratings = record.get("scores", [])
        if ratings:
            centre = float(np.mean(ratings))
            deviations.extend(float(rating) - centre for rating in ratings)
        elif record.get("mean_score") is not None:
            # Only the mean is known: mean − mean = 0.
            deviations.append(0.0)
    return deviations


def _collect_truth_scores_from_csv(csv_path: str, collapse_age_bands: bool = True) -> List[float]:
    """Return per-website group means normalized within each website.

    Source format: website, group, mean_response

    If collapse_age_bands is True, the "group" is assumed to be of the form
    "<age_band>_<other>" (e.g., "18-24_female", "18-24_male"). We will:
      - collapse within each website+age_band by averaging across subgroups
      - then compute deviations: (age_band_mean - website_mean_of_age_band_means)

    If collapse_age_bands is False, we use each unique group as-is per website.

    Rows are skipped when the website or mean_response cell is missing, or when
    mean_response is not a finite number ("NA", "NaN", "inf", free text, ...);
    a single NaN would otherwise turn every deviation of that website into NaN.
    Rows with an empty group are filed under the key "unknown".

    Returns:
        One deviation per (website, group key), websites and groups in
        first-seen order.
    """
    # website -> (group_key -> list of mean_response for that group_key)
    website_to_group_values: Dict[str, Dict[str, List[float]]] = defaultdict(lambda: defaultdict(list))

    # newline="" is what the csv module asks for so that it can handle line
    # endings (including ones embedded in quoted fields) itself.
    with open(csv_path, "r", encoding="utf-8", newline="") as fh:
        for row in csv.DictReader(fh):
            website = row.get("website")
            group_raw = row.get("group")
            mean_resp_str = row.get("mean_response")
            if website is None or mean_resp_str is None:
                continue
            try:
                mean_resp = float(mean_resp_str)
            except (TypeError, ValueError):
                continue
            if not np.isfinite(mean_resp):
                # float() happily parses "nan"/"inf"; they are not usable ratings.
                continue

            if collapse_age_bands and group_raw:
                # Extract age band before first underscore, e.g., "18-24_female" -> "18-24"
                group_key = group_raw.split("_", 1)[0]
            else:
                group_key = group_raw or "unknown"

            website_to_group_values[website][group_key].append(mean_resp)

    # For each website, compute per-group_key mean, then center by website mean
    truth_scores: List[float] = []
    for group_map in website_to_group_values.values():
        # Average across subgroups (e.g., male/female) within an age band if applicable
        per_group_means = [float(np.mean(vals)) for vals in group_map.values() if vals]
        if not per_group_means:
            continue
        website_mean = float(np.mean(per_group_means))
        truth_scores.extend(val - website_mean for val in per_group_means)

    return truth_scores


def _collect_truth_scores_from_individual_csv(
    csv_path: str,
    website_column: str = "website",
    response_column: str = "mean_response",
) -> List[float]:
    """Return per-website individual responses normalized within each website.

    - Uses the column specified by response_column (default: "mean_response")
      as the individual's prediction/response.
    - Groups rows by website
    - For each website, computes mean(response) across all individuals
    - Appends (individual_response - website_mean) for each individual
    """
    website_to_scores: Dict[str, List[float]] = defaultdict(list)

    def _parse_float(value: str):
        try:
            return float(value)
        except (TypeError, ValueError):
            return None

    with open(csv_path, "r", encoding="utf-8") as fh:
        reader = csv.DictReader(fh)
        for row in reader:
            website = row.get(website_column)
            if not website:
                continue
            val_raw = row.get(response_column)
            if val_raw in (None, "", "NA", "NaN"):
                continue
            score_val = _parse_float(val_raw)
            if score_val is None:
                continue
            website_to_scores[website].append(score_val)

    truth_scores: List[float] = []
    for website, scores in website_to_scores.items():
        if not scores:
            continue
        website_mean = float(np.mean(scores))
        for s in scores:
            truth_scores.append(float(s) - website_mean)

    return truth_scores


def _image_path_to_website_id(image_path: str) -> str:
    """Map an image path such as 'english_resized/327.png' to 'english_327'.

    The id is "<first '_'-separated token of the parent folder>_<file name
    without its last extension>". Any failure (for example a path with no
    parent folder) yields the empty string.
    """
    try:
        # Unpacking needs at least a folder and a file name; fewer parts raise.
        *_, folder, filename = image_path.strip("/").split("/")
        stem = filename.rsplit(".", 1)[0]
        prefix = folder.split("_")[0]
    except Exception:
        return ""
    return f"{prefix}_{stem}"


def _load_individual_truth_and_means(
    csv_path: str,
    website_column: str = "website",
    response_column: str = "mean_response",
) -> tuple[List[float], Dict[str, float]]:
    """Load human responses and return (centered_scores, website_to_mean).

    For each website, compute its mean over all usable human responses, then
    append each individual's deviation from that website mean.

    Rows are skipped when the website cell is empty/missing or when the
    response is missing, empty, or not a finite number. The finiteness check
    covers the literal markers "NA"/"NaN" and every other spelling float()
    would turn into nan/inf (e.g. "nan", "inf"); one such value would
    otherwise make the website mean, and all of its deviations, NaN.

    Returns:
        centered_scores: (response - website_mean) for every kept row,
            websites in first-seen order.
        website_to_mean: website id -> mean of its kept responses.
    """
    website_to_scores: Dict[str, List[float]] = defaultdict(list)

    # newline="" is what the csv module asks for so that it can handle line
    # endings (including ones embedded in quoted fields) itself.
    with open(csv_path, "r", encoding="utf-8", newline="") as fh:
        for row in csv.DictReader(fh):
            website = row.get(website_column)
            if not website:
                continue
            val_raw = row.get(response_column)
            if val_raw in (None, "", "NA", "NaN"):
                continue
            try:
                score_val = float(val_raw)
            except (TypeError, ValueError):
                continue
            if not np.isfinite(score_val):
                continue
            website_to_scores[website].append(score_val)

    website_to_mean: Dict[str, float] = {}
    centered_scores: List[float] = []
    for website, scores in website_to_scores.items():
        if not scores:
            continue
        mean_val = float(np.mean(scores))
        website_to_mean[website] = mean_val
        centered_scores.extend(s - mean_val for s in scores)

    return centered_scores, website_to_mean


# -----------------------------------------------------------------------------
# PER-SAMPLE SERIES BUILDERS AND MAIN
# -----------------------------------------------------------------------------

def _collect_social_persona_within_sample_normalized(data: List[dict]) -> List[float]:
    """Centre every sample's persona predictions on that sample's own mean.

    Each emitted value is (persona_prediction − mean of the persona
    predictions of the same record). Records for which
    ``_extract_persona_predictions`` finds nothing are ignored.
    """
    deviations: List[float] = []
    for record in data:
        per_persona = _extract_persona_predictions(record)
        if per_persona:
            centre = float(np.mean(per_persona))
            deviations.extend(float(value) - centre for value in per_persona)
    return deviations


def _collect_social_persona_minus_truth(
    data: List[dict], website_to_mean: Dict[str, float]
) -> List[float]:
    """Return (persona_prediction − human website mean) for every persona.

    The record's image path (key "image", else "img") is converted to a
    website id with ``_image_path_to_website_id`` and looked up in
    *website_to_mean*. Records whose id is not in the mapping, or that carry
    no persona predictions, contribute nothing.
    """
    deltas: List[float] = []
    for record in data:
        image_ref = record.get("image") or record.get("img") or ""
        site_mean = website_to_mean.get(_image_path_to_website_id(str(image_ref)))
        if site_mean is None:
            continue
        deltas.extend(
            float(value) - float(site_mean)
            for value in _extract_persona_predictions(record)
        )
    return deltas


def _collect_vanilla_replicates_minus_truth(
    social_data: List[dict],
    vanilla_data: List[dict],
    website_to_mean: Dict[str, float],
) -> List[float]:
    """Return (replicate_prediction − human website mean) for every replicate.

    Vanilla records are paired positionally with social records (stopping at
    the shorter list). The image path used to look up the website mean is
    taken from the social record (key "image", else "img"), while the
    replicate predictions come from the vanilla record via
    ``_extract_all_predictions``. Pairs without a known website mean or
    without replicates contribute nothing.
    """
    deltas: List[float] = []
    for social_record, vanilla_record in zip(social_data, vanilla_data):
        image_ref = social_record.get("image") or social_record.get("img") or ""
        site_mean = website_to_mean.get(_image_path_to_website_id(str(image_ref)))
        if site_mean is None:
            continue
        deltas.extend(
            float(replicate) - float(site_mean)
            for replicate in _extract_all_predictions(vanilla_record)
        )
    return deltas


def main():
    """Render the KDE comparison figure and cache the plotted series.

    Steps:
      1. Derive axis label and output/cache file names from ``TRUTH_SOURCE``.
      2. Load the three series (truth, social, vanilla) from the pickle cache
         next to this script, or recompute them from the configured JSON/CSV
         inputs when the cache is missing, unreadable or incomplete.
      3. Fit a 1-D Gaussian KDE per series and draw the curves on a shared
         x-axis that is symmetric around 0.
      4. Save the PNG and (re)write the cache.

    Raises:
        ImportError: if scipy is not installed.
    """
    # Fail fast: check the optional dependency before any expensive I/O.
    try:
        from scipy.stats import gaussian_kde
    except ImportError as err:
        raise ImportError("scipy is required for KDE. Install via pip install scipy") from err

    # Decide labels first (avoid heavy IO until needed)
    if TRUTH_SOURCE == "individual":
        truth_label_suffix = "individual"
        x_axis_label = "Deviation from Website Ground Truth Mean Score"
    else:
        truth_label_suffix = "group means"
        x_axis_label = "Deviation from website mean score (group means)"
    social_label = "Social Agents"
    vanilla_label = "Law of Large Numbers"

    # Decide output file path based on truth source (save alongside this script)
    file_stem = f"divergence_kde_2d_{truth_label_suffix.replace(' ', '_')}"
    output_path = DATA_DIR / f"{file_stem}.png"
    cache_path = DATA_DIR / f"{file_stem}.pkl"

    # Optional cache: load precomputed series if available.
    # NOTE: the cache is keyed only by the truth source, not by the input
    # files; delete the .pkl after changing any input to force a recompute.
    # SECURITY: pickle.load executes arbitrary code from the file. This is
    # only acceptable because the cache is written by this script itself;
    # never point cache_path at a file from an untrusted source.
    truth_scores = None
    social_series = None
    vanilla_series = None
    if cache_path.exists():
        try:
            with open(cache_path, "rb") as fh:
                cache = pickle.load(fh)
            truth_scores = cache.get("truth_scores")
            social_series = cache.get("social_series")
            vanilla_series = cache.get("vanilla_series")
        except Exception as e:
            # Best effort: a broken cache must never prevent a fresh computation.
            print(f"Warning: could not load cache from {cache_path}: {e}")

    # Compute if cache was missing/incomplete
    if not truth_scores or not social_series or not vanilla_series:
        social_data = _load_json(SOCIAL_JSON)
        vanilla_data = _load_json(VANILLA_JSON)

        if TRUTH_SOURCE == "individual":
            truth_scores, website_to_mean = _load_individual_truth_and_means(
                HUMAN_INDIVIDUAL_CSV
            )
            social_series = _collect_social_persona_minus_truth(social_data, website_to_mean)
            vanilla_series = _collect_vanilla_replicates_minus_truth(
                social_data, vanilla_data, website_to_mean
            )
        else:
            social_series = _collect_preds_social_normalized(social_data)
            vanilla_series = _collect_preds_vanilla_normalized(social_data, vanilla_data)
            truth_scores = _collect_truth_scores_from_csv(HUMAN_GROUP_CSV)

    # ---------------------------------------------------------------------
    # KDEs of the three deviation series
    # ---------------------------------------------------------------------

    # Parameters for labels and x-limits
    def _mean_std_safe(arr: List[float]):
        if len(arr) == 0:
            return float('nan'), float('nan')
        return float(np.mean(arr)), float(np.std(arr))

    def _fit_kde(arr: List[float]):
        # gaussian_kde needs at least two points, and its covariance is
        # singular when every value is identical; such a series is skipped
        # (returns None) instead of aborting the whole figure.
        if len(arr) < 2 or min(arr) == max(arr):
            return None
        return gaussian_kde(arr)

    mu_social, sigma_social = _mean_std_safe(social_series)
    mu_vanilla, sigma_vanilla = _mean_std_safe(vanilla_series)
    mu_truth, sigma_truth = _mean_std_safe(truth_scores)

    kde_social = _fit_kde(social_series)
    kde_vanilla = _fit_kde(vanilla_series)
    kde_truth = _fit_kde(truth_scores)

    # Range for plotting - accommodate whichever series are available
    def _series_range(mu_val: float, sigma_val: float, arr: List[float]):
        # Half-width needed to show this series, or None if it has no points.
        if len(arr) >= 2 and not (np.isnan(mu_val) or np.isnan(sigma_val)):
            return float(abs(mu_val) + 3.0 * sigma_val)
        if len(arr) == 1:
            return float(max(1.0, abs(arr[0]) + 1.0))
        return None

    candidate_ranges = [
        _series_range(mu_social, sigma_social, social_series),
        _series_range(mu_vanilla, sigma_vanilla, vanilla_series),
        _series_range(mu_truth, sigma_truth, truth_scores),
    ]
    candidate_ranges = [r for r in candidate_ranges if r is not None and np.isfinite(r)]
    half_range = max(candidate_ranges) if candidate_ranges else 1.0
    x_min, x_max = -half_range, half_range
    x_grid = np.linspace(x_min, x_max, 512)

    # ---------------------------- 2-D KDE plot ---------------------------
    fig2d, ax2d = plt.subplots(figsize=(9, 7))

    # KDE values
    vanilla_y = kde_vanilla(x_grid) if kde_vanilla is not None else None
    social_y = kde_social(x_grid) if kde_social is not None else None
    truth_y = kde_truth(x_grid) if kde_truth is not None else None

    # Draw vanilla first (so it appears behind)
    if vanilla_y is not None:
        ax2d.fill_between(x_grid, vanilla_y, color=VANILLA_COLOR, alpha=0.35,
                          label=f"{vanilla_label} (σ={sigma_vanilla:.2f})")
        ax2d.plot(x_grid, vanilla_y, color=EDGE_VANILLA, linewidth=1)

    # Draw ground truth second
    if truth_y is not None:
        ax2d.fill_between(x_grid, truth_y, color=TRUTH_COLOR, alpha=0.35,
                          label="Ground Truth")
        ax2d.plot(x_grid, truth_y, color=(0.5, 0.4, 0.7, 1.0), linewidth=1)

    # Draw social agents last (so it appears on top)
    if social_y is not None:
        ax2d.fill_between(x_grid, social_y, color=SOCIAL_COLOR, alpha=0.35,
                          label=f"{social_label} (σ={sigma_social:.2f})")
        ax2d.plot(x_grid, social_y, color=EDGE_SOCIAL, linewidth=1)

    # Draw axes crossing at 0 to mimic the sketch
    ax2d.axvline(0, color="#666666", linewidth=1.5, alpha=0.8)
    ax2d.axhline(0, color="#666666", linewidth=1.0, alpha=0.6)

    # Improved styling
    ax2d.set_xlabel(x_axis_label, fontsize=16, color="#333333", labelpad=12)
    ax2d.set_ylabel("Density", fontsize=16, color="#333333", labelpad=12)
    ax2d.set_title("Social Agents vs. Law of Large Numbers", fontsize=22,
                   color="#333333", pad=20)

    # Better tick styling
    ax2d.tick_params(axis='both', colors="#333333", labelsize=12)
    ax2d.grid(axis="y", linestyle="--", alpha=0.3, color="#cccccc")

    # Improved legend positioning and styling
    legend = ax2d.legend(loc="lower center", bbox_to_anchor=(LEGEND_X, LEGEND_Y),
                        ncol=3, frameon=True, fontsize=13, borderpad=1.0,
                        fancybox=True, shadow=False, framealpha=1.0)
    # Remove gray background; keep a subtle edge
    legend.get_frame().set_facecolor('white')
    legend.get_frame().set_edgecolor('#cccccc')

    # Better layout adjustment
    fig2d.tight_layout()
    fig2d.subplots_adjust(bottom=0.28)  # extra space for legend
    # Ensure directory exists and save
    output_path.parent.mkdir(parents=True, exist_ok=True)
    fig2d.savefig(output_path, dpi=300, bbox_inches="tight",
                  facecolor='white', edgecolor='none')
    # Release the figure so repeated calls (e.g. from a notebook) do not
    # accumulate open figures.
    plt.close(fig2d)

    # Improved console output formatting
    print(f"\n{'='*60}")
    print(f"Saved plot to: {output_path}")
    # Save cache of computed series for faster future runs
    try:
        with open(cache_path, "wb") as fh:
            pickle.dump({
                "truth_scores": truth_scores,
                "social_series": social_series,
                "vanilla_series": vanilla_series,
            }, fh)
        print(f"Saved cached series to: {cache_path}")
    except Exception as e:
        print(f"Warning: could not save cache: {e}")

if __name__ == "__main__":
    main()